{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Does nn.Conv2d init work well?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=21)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_02 import *\n",
    "\n",
    "def get_data():\n",
    "    path = datasets.download_data(MNIST_URL, ext='.gz')\n",
    "    with gzip.open(path, 'rb') as f:\n",
    "        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n",
    "    return map(tensor, (x_train,y_train,x_valid,y_valid))\n",
    "\n",
    "def normalize(x, m, s): return (x-m)/s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.nn.modules.conv._ConvNd.reset_parameters??"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train,y_train,x_valid,y_valid = get_data()\n",
    "train_mean,train_std = x_train.mean(),x_train.std()\n",
    "x_train = normalize(x_train, train_mean, train_std)\n",
    "x_valid = normalize(x_valid, train_mean, train_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([50000, 1, 28, 28]), torch.Size([10000, 1, 28, 28]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train = x_train.view(-1,1,28,28)\n",
    "x_valid = x_valid.view(-1,1,28,28)\n",
    "x_train.shape,x_valid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, tensor(10))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n,*_ = x_train.shape\n",
    "c = y_train.max()+1\n",
    "nh = 32\n",
    "n,c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "l1 = nn.Conv2d(1, nh, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = x_valid[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 1, 28, 28])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stats(x): return x.mean(),x.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 1, 5, 5])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l1.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor(-0.0043, grad_fn=<MeanBackward1>),\n",
       "  tensor(0.1156, grad_fn=<StdBackward0>)),\n",
       " (tensor(0.0212, grad_fn=<MeanBackward1>),\n",
       "  tensor(0.1176, grad_fn=<StdBackward0>)))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stats(l1.weight),stats(l1.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = l1(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0107, grad_fn=<MeanBackward1>),\n",
       " tensor(0.5978, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stats(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0267, grad_fn=<MeanBackward1>),\n",
       " tensor(1.1067, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "init.kaiming_normal_(l1.weight, a=1.)\n",
    "stats(l1(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1(x,a=0): return F.leaky_relu(l1(x),a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.5547, grad_fn=<MeanBackward1>),\n",
       " tensor(1.0199, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "init.kaiming_normal_(l1.weight, a=0)\n",
    "stats(f1(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.2219, grad_fn=<MeanBackward1>),\n",
       " tensor(0.3653, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l1 = nn.Conv2d(1, nh, 5)\n",
    "stats(f1(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 1, 5, 5])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l1.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "25"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# receptive field size\n",
    "rec_fs = l1.weight[0,0].numel()\n",
    "rec_fs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32, 1)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nf,ni,*_ = l1.weight.shape\n",
    "nf,ni"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(25, 800)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fan_in  = ni*rec_fs\n",
    "fan_out = nf*rec_fs\n",
    "fan_in,fan_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gain(a): return math.sqrt(2.0 / (1 + a**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0,\n",
       " 1.4142135623730951,\n",
       " 1.4141428569978354,\n",
       " 1.4071950894605838,\n",
       " 0.5773502691896257)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gain(1),gain(0),gain(0.01),gain(0.1),gain(math.sqrt(5.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.5788)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.zeros(10000).uniform_(-1,1).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5773502691896258"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1/math.sqrt(3.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kaiming2(x,a, use_fan_out=False):\n",
    "    nf,ni,*_ = x.shape\n",
    "    rec_fs = x[0,0].shape.numel()\n",
    "    fan = nf*rec_fs if use_fan_out else ni*rec_fs\n",
    "    std = gain(a) / math.sqrt(fan)\n",
    "    bound = math.sqrt(3.) * std\n",
    "    x.data.uniform_(-bound,bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.5603, grad_fn=<MeanBackward1>),\n",
       " tensor(1.0921, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kaiming2(l1.weight, a=0);\n",
    "stats(f1(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.2186, grad_fn=<MeanBackward1>),\n",
       " tensor(0.3437, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kaiming2(l1.weight, a=math.sqrt(5.))\n",
    "stats(f1(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Flatten(nn.Module):\n",
    "    def forward(self,x): return x.view(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = nn.Sequential(\n",
    "    nn.Conv2d(1,8, 5,stride=2,padding=2), nn.ReLU(),\n",
    "    nn.Conv2d(8,16,3,stride=2,padding=1), nn.ReLU(),\n",
    "    nn.Conv2d(16,32,3,stride=2,padding=1), nn.ReLU(),\n",
    "    nn.Conv2d(32,1,3,stride=2,padding=1),\n",
    "    nn.AdaptiveAvgPool2d(1),\n",
    "    Flatten(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = y_valid[:100].float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0875, grad_fn=<MeanBackward1>),\n",
       " tensor(0.0065, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = m(x)\n",
    "stats(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "l = mse(t,y)\n",
    "l.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0054), tensor(0.0333))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stats(m[0].weight.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init.kaiming_uniform_??"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for l in m:\n",
    "    if isinstance(l,nn.Conv2d):\n",
    "        init.kaiming_uniform_(l.weight)\n",
    "        l.bias.data.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(-0.0352, grad_fn=<MeanBackward1>),\n",
       " tensor(0.4043, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = m(x)\n",
    "stats(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0093), tensor(0.4231))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l = mse(t,y)\n",
    "l.backward()\n",
    "stats(m[0].weight.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!./notebook2script.py 02a_why_sqrt5.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
